import os
import random
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed

from utils.util import read_json
from dsl_design.src.segmentation import Segmentation
from dsl_design.src.extract import Extract
from dsl_design.src.recognize import Recognize
from dsl_design.src.annotation import *
from dsl_design.src.feature import Feature
from dsl_design.src.operation import Operation
from dsl_design.src.production import Production

# Short domain keys accepted via ``args.domain`` mapped to the full domain
# name that ``autodsl`` uses as the per-domain directory under ``args.data``.
name_mapping = dict(
    Genetics="Molecular Biology & Genetics",
    Medical="Biomedical & Clinical Research",
    Ecology="Ecology & Environmental Biology",
    BioEng="Bioengineering & Technology",
)

def _load_demo_procedures(origin, num_protocols=3):
    """Return the flattened procedures of a few randomly chosen protocols.

    Protocol JSON files are drawn without replacement from the directory
    ``origin``; protocols tagged with the "Bioinformatics & Computational
    Biology" big area are skipped. Sampling stops once ``num_protocols``
    protocols have been accepted or when no candidate files remain, so a
    small directory cannot make ``random.choice`` fail on an empty list.
    """
    candidates = os.listdir(origin)
    chosen = []
    while len(chosen) < num_protocols and candidates:
        filename = random.choice(candidates)
        candidates.remove(filename)
        # Read from the same directory the candidates were listed from.
        protocol = read_json(os.path.join(origin, filename))
        if "Bioinformatics & Computational Biology" not in protocol.get("bigAreas", []):
            chosen.append(protocol.get("procedures", []))
    return [procedure for procedures in chosen for procedure in procedures]


def _load_domain_procedures(origin, area_names):
    """Return the procedures of every protocol in ``origin`` whose ``bigAreas`` hits ``area_names``."""
    procedures = []
    for filename in os.listdir(origin):
        protocol = read_json(os.path.join(origin, filename))
        if any(area in area_names for area in protocol.get("bigAreas", [])):
            procedures.extend(protocol.get("procedures", []))
    return procedures


def _run_segmentation(args, domain_name, data_dir):
    """Segment protocol procedures into sentences and write ``sentences.json``."""
    if args.demo:
        protocols = _load_demo_procedures(args.origin)
    else:
        # ``bigAreas`` appears to hold full area names (compare the demo-mode
        # "Bioinformatics & Computational Biology" filter and the values of
        # ``name_mapping``), while ``args.domain`` is a short key. Accept
        # either spelling so the filter can match.
        protocols = _load_domain_procedures(args.origin, {args.domain, domain_name})
    segmentation = Segmentation(protocols)
    segmentation.segmentation(os.path.join(data_dir, "sentences.json"))


def _run_extract(data_dir):
    """Filter extracted opcodes, annotate their superclasses, and sample sentences."""
    extracter = Extract(
        input_data_path=os.path.join(data_dir, "sentences.json"),
        store_path=os.path.join(data_dir, "extracted.json"),
    )
    # Extraction is disabled; presumably extracted.json already exists from
    # an earlier run -- TODO confirm before relying on this stage alone.
    # extracter.extract()
    # Filter opcodes.
    extracter.filter_extract_data(threshold=300)
    # Operation superclass annotation.
    oa = OperationAnnotation(
        input_data_path=os.path.join(data_dir, "extracted_filtered.json")
    )
    oa.annotate_with_context(store_path=os.path.join(data_dir, "superclass_to_opcode.json"))
    # Sample sentences.
    extracter.sample_sentences(threshold=500)


def _run_recognize(data_dir):
    """Annotate component superclasses and judge device aliases in ``recognized.json``."""
    recognized_path = os.path.join(data_dir, "recognized.json")
    # recognizer = Recognize(
    #     extracted_data_path=os.path.join(data_dir, "extracted_filtered.json"),
    #     recognized_store_path=recognized_path,
    # )
    # recognizer.recognize_whole()
    # Component superclass annotation; this changes the input recognized data.
    ca = ComponentAnnotation(input_data_path=recognized_path)
    ca.annotate(store_path=os.path.join(data_dir, "superclass_to_flowunit.json"))
    # Merge same devices; does not change the input recognized data.
    aj_device = AliasJudgement(
        entity="device",
        recognized_data_path=recognized_path,
        same_entities_store_path=os.path.join(data_dir, "same_devices.json"),
    )
    aj_device.annotate()
    # Merge same components; does not change the input recognized data.
    # aj_component = AliasJudgement(
    #     entity="component",
    #     recognized_data_path=recognized_path,
    #     same_entities_store_path=os.path.join(data_dir, "same_components.json"),
    # )
    # aj_component.annotate()


def _run_operation(data_dir):
    """Cluster each opcode's feature vectors in parallel and dump the operation DSL."""
    feature = Feature(
        recognized_data_path=os.path.join(data_dir, "recognized.json"),
        feature_data_path=os.path.join(data_dir, "feature_data.json"),
        same_devices=read_json(os.path.join(data_dir, "same_devices.json")),
    )
    operation = Operation(
        feature=feature,
        operation_dsl_path=os.path.join(data_dir, "operation_dsl.json"),
    )

    def process_opcode(opcode):
        idx_list_h1, value_list_h1 = feature.feature_vector_extraction(opcode=opcode, hierarchy=1)
        operation.recursive_clustering(opcode, idx_list_h1, value_list_h1, hierarchy=1)
        operation.analyse(opcode)
        return opcode

    # Process opcodes in parallel with a thread pool.
    # NOTE(review): every worker shares `feature` and `operation`; confirm
    # their methods are safe to call concurrently.
    with ThreadPoolExecutor() as executor:
        futures = [executor.submit(process_opcode, opcode) for opcode in feature.feature_data]
        # Use tqdm to show progress as opcodes finish.
        for future in tqdm(as_completed(futures), total=len(futures), desc="Abstraction"):
            future.result()  # Retrieve the result so any worker exception is re-raised here.

    operation.dump_result()


def _run_production(domain, data_dir):
    """Run the production stage's simple operation clustering for ``domain``."""
    # same_components = read_json(os.path.join(data_dir, "same_components.json"))
    production = Production(
        recognized_data_path=os.path.join(data_dir, "recognized.json"),
        # NOTE(review): unlike the other stages this reads from
        # dsl_result/<domain>/, not the data directory handed to Operation as
        # operation_dsl_path in the "operation" stage -- confirm the file is
        # placed there.
        operation_dsl_path=os.path.join("dsl_result", domain, "operation_dsl.json"),
        same_components={},
    )
    # production.extract(os.path.join(data_dir, "production_dsl.json"))
    # production.cluster(domain=domain, product_dsl_path=f"dsl_result/{domain}/product_dsl.json")
    production.operation_simple_clustering(domain=domain)


def autodsl(args):
    """Run one stage of the AutoDSL pipeline, selected by ``args.mode``.

    Supported modes are "segmentation", "extract", "recognize", "operation"
    and "production". Each stage reads and writes JSON files under
    ``<args.data>/<domain name>/``. The domain name is ``"demo"`` when
    ``args.demo`` is set. Otherwise it is ``name_mapping[args.domain]``,
    falling back to ``"demo"`` for unknown keys.

    Args:
        args: Parsed command-line namespace. The attributes read are
            ``mode``, ``demo``, ``domain``, ``data`` and, for segmentation,
            ``origin`` (a directory of protocol JSON files).

    An unknown mode prints "Wrong Mode!" and does nothing else.
    """
    domain_name = "demo" if args.demo else name_mapping.get(args.domain, "demo")
    data_dir = os.path.join(args.data, domain_name)

    if args.mode == "segmentation":
        _run_segmentation(args, domain_name, data_dir)
    elif args.mode == "extract":
        _run_extract(data_dir)
    elif args.mode == "recognize":
        _run_recognize(data_dir)
    elif args.mode == "operation":
        _run_operation(data_dir)
    elif args.mode == "production":
        _run_production(args.domain, data_dir)
    else:
        print("Wrong Mode!")